# Fairness Metrics

A key insight that makes this library so flexible with respect to how fairness is defined is that many fairness definitions simply compare a statistic between groups.

In this notebook, we express many well-known fairness definitions in terms of statistics using the `LinearFractionalParity` class. The class is implemented as a `torchmetrics` `Metric`, so it can be integrated into any PyTorch training or evaluation loop. All it needs is a statistic to compare.

To start, make sure `torchmetrics` is installed:

In [1]:
import torchmetrics

If it isn't, you can install it with:

``pip install torchmetrics``

In [2]:
from fairret.metric import LinearFractionalParity

The simplest statistic to compare is the positive rate, i.e. the rate at which positive predictions are made for each sensitive feature. Equal positive rates are typically referred to as *demographic parity* (a.k.a. *statistical parity* or *equal acceptance rate*).
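To make the statistic concrete, here is a framework-free sketch in plain Python (no `torch` or `fairret`) of a soft positive rate per sensitive feature, computed as the mean prediction within each group. The predictions are made-up values and `positive_rate_per_group` is a hypothetical helper, not part of `fairret`. Note that it returns one value per column of the sensitive-feature matrix.

```python
# Toy data: one-hot sensitive features, one column per group.
toy_sens = [[1., 0.], [1., 0.], [0., 1.], [0., 1.]]
toy_pred = [0.9, 0.3, 0.6, 0.2]  # hypothetical predicted probabilities


def positive_rate_per_group(pred, sens):
    """Soft positive rate per sensitive feature: sum(pred * s_k) / sum(s_k)."""
    n_groups = len(sens[0])
    rates = []
    for k in range(n_groups):
        num = sum(p * s[k] for p, s in zip(pred, sens))
        den = sum(s[k] for s in sens)
        rates.append(num / den)
    return rates


rates = positive_rate_per_group(toy_pred, toy_sens)
print(rates)  # one value per sensitive feature
```

Demographic parity would hold if both entries were equal; here the first group has a mean prediction of about 0.6 and the second of about 0.4.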

Hence, we can express the extent to which demographic parity holds by passing the `PositiveRate` statistic to a `LinearFractionalParity` metric:

In [3]:
from fairret.statistic import PositiveRate
demographic_parity = LinearFractionalParity(PositiveRate())

TypeError: LinearFractionalParity.__init__() missing 1 required positional argument: 'stat_shape'

Oops, we also need to provide the shape of the statistic. Most often, this is simply the number of sensitive features: the `PositiveRate` statistic, for example, computes a single value per sensitive feature.

To know the number of sensitive features, we need to define some data first:

In [4]:
import torch
torch.manual_seed(0)

feat = torch.tensor([[1., 2.], [3., 4.], [5., 6.], [7., 8.]])
sens = torch.tensor([[1., 0.], [1., 0.], [0., 1.], [0., 1.]])
label = torch.tensor([[0.], [1.], [0.], [1.]])

n_sensitive_features = sens.shape[1]
print(f"Number of sensitive features: {n_sensitive_features}")

Number of sensitive features: 2


We can now construct the metric:

In [5]:
demographic_parity = LinearFractionalParity(PositiveRate(), stat_shape=(n_sensitive_features,))

`LinearFractionalParity` follows the same interface as all other `Metric` classes in `torchmetrics`.

For all details of this interface, check out the `torchmetrics` [documentation](https://lightning.ai/docs/torchmetrics/stable/pages/quickstart.html#module-metrics).

In short, it follows a three-step approach:
1. Call `metric.update(args)`, where `args` in our case are the arguments necessary to compute the statistic.
2. Call `metric.compute()`, which returns the violation of the fairness definition.
3. Call `metric.reset()`, which resets the metric to its initial state.
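The three steps can be illustrated with a small, dependency-free class. `PositiveRateParitySketch` is hypothetical: it is not how `fairret` or `torchmetrics` implement metrics, and it assumes one-hot sensitive features and the relative absolute gap that `LinearFractionalParity` uses by default. It only mirrors the update/compute/reset cycle.

```python
class PositiveRateParitySketch:
    """Hypothetical, plain-Python stand-in for a torchmetrics-style metric.

    update() accumulates sums, compute() turns them into a violation,
    reset() clears the accumulated state.
    """

    def __init__(self, n_groups):
        self.n_groups = n_groups
        self.reset()

    def reset(self):
        # Running numerators (sum of predictions) and denominators (group sizes).
        self.group_num = [0.0] * self.n_groups
        self.group_den = [0.0] * self.n_groups
        self.total_num = 0.0
        self.total_den = 0.0

    def update(self, pred, sens):
        for p, s in zip(pred, sens):
            for k in range(self.n_groups):
                self.group_num[k] += p * s[k]
                self.group_den[k] += s[k]
            self.total_num += p
            self.total_den += 1.0

    def compute(self):
        overall = self.total_num / self.total_den
        rates = [n / d for n, d in zip(self.group_num, self.group_den)]
        # Relative absolute gap to the overall rate, maximised over groups.
        return max(abs(r - overall) / overall for r in rates)


metric = PositiveRateParitySketch(n_groups=2)
metric.update([0.9, 0.3], [[1., 0.], [1., 0.]])  # step 1, first batch
metric.update([0.6, 0.2], [[0., 1.], [0., 1.]])  # step 1, second batch
violation = metric.compute()                     # step 2
metric.reset()                                   # step 3
print(violation)
```

With group rates of about 0.6 and 0.4 and an overall rate of 0.5, the sketch reports a violation of about 0.2.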

If you need to use any other `torchmetrics` settings, such as `compute_with_cache`, you can pass them as keyword arguments to the `LinearFractionalParity` class upon initialization.

## Example

Let's train a model and keep track of the demographic parity.

Without fairret:

In [6]:
h_layer_dim = 16
lr = 1e-3
batch_size = 1024

torch.manual_seed(0)

def build_model():
    model = torch.nn.Sequential(
        torch.nn.Linear(feat.shape[1], h_layer_dim),
        torch.nn.ReLU(),
        torch.nn.Linear(h_layer_dim, 1)
    )
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    return model, optimizer

from torch.utils.data import TensorDataset, DataLoader
dataset = TensorDataset(feat, sens, label)
dataloader = DataLoader(dataset, batch_size=batch_size)

In [7]:
import numpy as np

nb_epochs = 100
model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        loss.backward()

        # Track demographic parity on the predicted probabilities.
        pred = torch.sigmoid(logit)
        demographic_parity.update(pred, batch_sens)

        optimizer.step()
        losses.append(loss.item())
    # Aggregate over all batches of the epoch, then clear the state for the next epoch.
    dp_for_epoch = demographic_parity.compute()
    demographic_parity.reset()
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}")

Epoch: 0, loss: 0.7091795206069946, dp: 0.03126168251037598
Epoch: 1, loss: 0.7061765193939209, dp: 0.02563762664794922
Epoch: 2, loss: 0.7033581733703613, dp: 0.020147204399108887
Epoch: 3, loss: 0.7007156610488892, dp: 0.014800786972045898
Epoch: 4, loss: 0.6982340812683105, dp: 0.009598910808563232
Epoch: 5, loss: 0.6959078907966614, dp: 0.00453948974609375
Epoch: 6, loss: 0.6937355995178223, dp: 0.00037485361099243164
Epoch: 7, loss: 0.6917158365249634, dp: 0.005139470100402832
Epoch: 8, loss: 0.6898466944694519, dp: 0.009749293327331543
Epoch: 9, loss: 0.6881252527236938, dp: 0.014199256896972656
Epoch: 10, loss: 0.6865478754043579, dp: 0.01848423480987549
Epoch: 11, loss: 0.6851094961166382, dp: 0.022599458694458008
Epoch: 12, loss: 0.6838041543960571, dp: 0.02654099464416504
Epoch: 13, loss: 0.6826250553131104, dp: 0.030305147171020508
Epoch: 14, loss: 0.6815641522407532, dp: 0.03388887643814087
Epoch: 15, loss: 0.6806124448776245, dp: 0.037290215492248535
Epoch: 16, loss: 0.679

With fairret:

In [8]:
from fairret.loss import NormLoss
fairness_strength = 1
norm_loss = NormLoss(PositiveRate())

model, optimizer = build_model()
for epoch in range(nb_epochs):
    losses = []
    for batch_feat, batch_sens, batch_label in dataloader:
        optimizer.zero_grad()

        logit = model(batch_feat)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logit, batch_label)
        # Add the fairness penalty, which is computed on the logits.
        loss += fairness_strength * norm_loss(logit, batch_sens)
        loss.backward()

        # Track demographic parity on the predicted probabilities.
        pred = torch.sigmoid(logit)
        demographic_parity.update(pred, batch_sens)

        optimizer.step()
        losses.append(loss.item())
    # Aggregate over all batches of the epoch, then clear the state for the next epoch.
    dp_for_epoch = demographic_parity.compute()
    demographic_parity.reset()
    print(f"Epoch: {epoch}, loss: {np.mean(losses)}, dp: {dp_for_epoch}")

Epoch: 0, loss: 0.7429860830307007, dp: 0.0319594144821167
Epoch: 1, loss: 0.7368547916412354, dp: 0.02823483943939209
Epoch: 2, loss: 0.7306909561157227, dp: 0.024389266967773438
Epoch: 3, loss: 0.7244950532913208, dp: 0.020473718643188477
Epoch: 4, loss: 0.7182670831680298, dp: 0.016491293907165527
Epoch: 5, loss: 0.7120081186294556, dp: 0.012444257736206055
Epoch: 6, loss: 0.7057176828384399, dp: 0.008333325386047363
Epoch: 7, loss: 0.6993964910507202, dp: 0.004159450531005859
Epoch: 8, loss: 0.6933507323265076, dp: 7.653236389160156e-05
Epoch: 9, loss: 0.6987630128860474, dp: 0.0022329092025756836
Epoch: 10, loss: 0.700627326965332, dp: 0.0029762983322143555
Epoch: 11, loss: 0.6999425292015076, dp: 0.0027066469192504883
Epoch: 12, loss: 0.697374701499939, dp: 0.0016865730285644531
Epoch: 13, loss: 0.693393886089325, dp: 9.85860824584961e-05
Epoch: 14, loss: 0.6960413455963135, dp: 0.0019237995147705078
Epoch: 15, loss: 0.6981306076049805, dp: 0.003303050994873047
Epoch: 16, loss: 0
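Both loops call `update` on every batch but `compute` only once per epoch. One reason this matters: the statistic is a ratio, so averaging one violation per batch is generally not the same as computing the violation over all samples of the epoch. The plain-Python sketch below assumes the relative absolute gap and uses made-up batches; `relative_abs_max_violation` is a hypothetical helper, not a `fairret` function.

```python
def relative_abs_max_violation(pred, sens):
    """Max over groups of |group rate - overall rate| / overall rate."""
    n_groups = len(sens[0])
    overall = sum(pred) / len(pred)
    gaps = []
    for k in range(n_groups):
        num = sum(p * s[k] for p, s in zip(pred, sens))
        den = sum(s[k] for s in sens)
        gaps.append(abs(num / den - overall) / overall)
    return max(gaps)


g0, g1 = [1., 0.], [0., 1.]
batches = [
    ([0.8, 0.4], [g0, g1]),
    ([0.2, 0.6, 0.6, 0.6], [g0, g1, g1, g1]),
]

# Averaging one value per batch weights each batch equally, whatever its size or group mix.
per_batch = [relative_abs_max_violation(bp, bs) for bp, bs in batches]
mean_of_batches = sum(per_batch) / len(per_batch)

# Pooling all samples of the epoch first gives the epoch-level value.
all_pred = [p for bp, _ in batches for p in bp]
all_sens = [s for _, bs in batches for s in bs]
epoch_value = relative_abs_max_violation(all_pred, all_sens)
print(mean_of_batches, epoch_value)
```

Here the per-batch average comes out near 0.47, while the epoch-level value is 0.0625.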

## Clarification on the exact value of the metric

Though it is generally agreed that demographic parity is achieved when the positive rates are equal, there is much ambiguity about how to measure the extent to which it is violated. In `fairret`, we assess this violation by comparing the statistic value for each sensitive feature to the statistic value for the entire population.

In our example, this means that we compare the mean prediction for each of the two groups to the overall mean prediction.

There is then one more ingredient: how the gap between these values is actually computed. We provide a few options, such as the absolute difference (`gap_abs_max`) and the relative absolute difference (`gap_relative_abs_max`). The former takes the maximum absolute difference between a group's statistic and the overall statistic, while the latter divides this maximum by the overall statistic. The relative absolute difference is the default behavior of the `LinearFractionalParity` class.
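The two gap functions described above can be sketched in plain Python as follows. These are hypothetical reimplementations written from the description in this notebook (the `_sketch` names are made up); the actual `fairret` functions operate on tensors and their exact signatures may differ.

```python
def gap_abs_max_sketch(group_stats, overall_stat):
    # Largest absolute difference between any group's statistic and the overall one.
    return max(abs(s - overall_stat) for s in group_stats)


def gap_relative_abs_max_sketch(group_stats, overall_stat):
    # The same maximum, expressed relative to the overall statistic.
    return gap_abs_max_sketch(group_stats, overall_stat) / overall_stat


group_rates = [0.6, 0.4]  # e.g. mean prediction per group
overall_rate = 0.5        # mean prediction over the whole population
abs_gap = gap_abs_max_sketch(group_rates, overall_rate)
rel_gap = gap_relative_abs_max_sketch(group_rates, overall_rate)
print(abs_gap, rel_gap)
```

The relative variant makes violations comparable across statistics with different base rates: a gap of 0.1 counts for more when the overall rate is 0.5 than when it is 0.9.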

To use another gap function, simply pass it as an argument to the `LinearFractionalParity` class. Of course, you're also free to implement your own!
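As an example of a custom gap, a hypothetical `gap_max_min_sketch` could measure the spread between the highest and lowest group statistic instead of comparing each group to the overall value. The `(group_stats, overall_stat)` signature is an assumption made for illustration; to pass a real gap function to `LinearFractionalParity`, check the signature that `fairret` expects and operate on tensors.

```python
def gap_max_min_sketch(group_stats, overall_stat):
    """Hypothetical custom gap: spread between the highest and lowest group statistic.

    overall_stat is unused; it is kept only to mirror the assumed
    (group statistics, overall statistic) calling convention.
    """
    return max(group_stats) - min(group_stats)


spread = gap_max_min_sketch([0.6, 0.4], overall_stat=0.5)
print(spread)
```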

## What's next?

In larger pipelines, you will likely want to define separate metrics for train, validation, and test set results. 

Also, you may want to assess many fairness definitions at once. These could all be defined as separate metrics, or you could use the `StackedLinearFractionalStatistic`, which keeps track of many statistics at the same time (see [Stacked Statistic.ipynb](<./Stacked Statistic.ipynb>)). However, keep in mind that `compute` then no longer returns a scalar, but a tensor that stacks the violations of all statistics.
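Conceptually, a stacked metric yields one violation per statistic. The plain-Python sketch below mimics that output without using `StackedLinearFractionalStatistic` itself: it evaluates the positive rate (demographic parity) and the true positive rate (equal opportunity) on toy sensitive features and labels, with made-up predictions, and collects both relative violations in a list. All helper names are hypothetical.

```python
def weighted_rate(pred, weights):
    # Linear-fractional form: sum(pred * w) / sum(w).
    return sum(p * w for p, w in zip(pred, weights)) / sum(weights)


def relative_abs_max_violation(pred, sens, sample_weights):
    """Relative max gap between per-group and overall weighted rates."""
    overall = weighted_rate(pred, sample_weights)
    gaps = []
    for k in range(len(sens[0])):
        group_weights = [w * s[k] for w, s in zip(sample_weights, sens)]
        gaps.append(abs(weighted_rate(pred, group_weights) - overall) / overall)
    return max(gaps)


toy_sens = [[1., 0.], [1., 0.], [0., 1.], [0., 1.]]
toy_label = [0., 1., 0., 1.]
toy_pred = [0.9, 0.5, 0.6, 0.1]  # hypothetical predicted probabilities

violations = [
    # Positive rate (demographic parity): every sample counts.
    relative_abs_max_violation(toy_pred, toy_sens, [1.0] * len(toy_pred)),
    # True positive rate (equal opportunity): only samples with label 1 count.
    relative_abs_max_violation(toy_pred, toy_sens, toy_label),
]
print(violations)
```

The result is roughly `[0.333, 0.667]`: one entry per statistic rather than a single scalar.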